{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 基于昇思MindSpore+Orangepi AIpro的CNNCTC文字识别"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "文本识别指从图像中识别出文本，将图像中的文字区域转化为字符信息，通常采用CNN网络从图像中提取丰富的特征信息，然后根据提取的特征信息进行识别。这里采用ResNet作为特征提取网络，采用CTC(Connectionist Temporal Classification)方法进行识别。\n",
    "此脚本将**基于昇思MindSpore框架开发的cnnctc模型训练出来的权重文件转换成OM文件**后，调用ACL相关接口进行离线推理。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 前期准备"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* 基础镜像的样例目录中已包含转换后的om模型以及测试图片，如果直接运行，可跳过此步骤。如果需要重新转换模型，可以参考下面的步骤。\n",
    "* **建议在Linux服务器或者虚拟机转换该模型。**\n",
    "* **为了能进一步优化模型推理性能，我们需要将其转换为om模型进行使用；转换指导详见全流程实验指导。**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## 环境配置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "!sh env.sh"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型推理实现 ##\n",
    "得到cnnctc.om后，执行离线推理代码,加载图片predict.png\n",
    "\n",
    "![jupyter](./predict.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. 导入三方库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import time\n",
    "import argparse\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "\n",
    "from acllite_model import AclLiteModel as Model\n",
    "from acllite_resource import AclLiteResource as AclResource"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 模型导入和处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 获取模型om文件\n",
    "from download import download\n",
    "model_url = \"https://modelers.cn/coderepo/web/v1/file/MindSpore-Lab/cluoud_obs/main/media/examples/mindspore-courses/orange-pi-mindspore/02-CNNCTC/cnnctc.zip\"\n",
    "download(model_url, \"./\", kind=\"zip\", replace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "init resource stage:\n",
      "Init resource success\n",
      "load model....\n",
      "Init model resource start...\n",
      "[AclLiteModel] create model output dataset:\n",
      "malloc output 0, size 3848\n",
      "Create model output dataset success\n",
      "Init model resource success\n",
      "load model finished....\n"
     ]
    }
   ],
   "source": [
    "# om模型和图片的位置\n",
    "MODEL_PATH = './cnnctc.om'\n",
    "IMAGE_PATH = './predict.png'\n",
    "\n",
    "# 初始化acl资源\n",
    "acl_resource = AclResource()\n",
    "acl_resource.init()\n",
    "\n",
    "#导入本地om模型\n",
    "print('load model....')\n",
    "model = Model(MODEL_PATH)\n",
    "print('load model finished....')\n",
    "\n",
    "# 文本与数据编码\n",
    "class CTCLabelConverter():\n",
    "    def __init__(self, character):\n",
    "        dict_character = list(character)\n",
    "        self.dict = {}\n",
    "        for i, char in enumerate(dict_character):\n",
    "            self.dict[char] = i + 1\n",
    "        self.character = ['[blank]'] + dict_character\n",
    "        self.dict['[blank]'] = 0\n",
    "\n",
    "    #将文本转换为数字编码\n",
    "    def encode(self, text):\n",
    "        length = [len(s) for s in text]\n",
    "        text = ''.join(text)\n",
    "        text = [self.dict[char] for char in text]\n",
    "\n",
    "        return np.array(text), np.array(length)\n",
    "\n",
    "    # 将数字编码转换为文本\n",
    "    def decode(self, text_index, length):\n",
    "        texts = []\n",
    "        index = 0\n",
    "        for l in length:\n",
    "            t = text_index[index:index + l]\n",
    "            char_list = []\n",
    "            for i in range(l):\n",
    "                if t[i] != self.dict['[blank]'] and (\n",
    "                        not (i > 0 and t[i - 1] == t[i])):\n",
    "                    char_list.append(self.character[t[i]])\n",
    "            text = ''.join(char_list)\n",
    "            texts.append(text)\n",
    "            index += l\n",
    "        return texts\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. 进行推理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "infer use time:8.622884750366211ms\n",
      "Predict:  ['parking']\n"
     ]
    }
   ],
   "source": [
    "# 导入和处理目标图片\n",
    "img_PIL = Image.open(IMAGE_PATH).convert('RGB')\n",
    "img = img_PIL.resize((100, 32), resample=3)\n",
    "img = np.array(img, dtype=np.float32)\n",
    "img = np.expand_dims(img, axis=0) \n",
    "img = np.transpose(img, [0, 3, 1, 2]) \n",
    "\n",
    "# 定义推理的时间\n",
    "start = time.time()\n",
    "model_predict = model.execute([img])[0]\n",
    "end = time.time()\n",
    "print(f'infer use time:{(end-start)*1000}ms')\n",
    "\n",
    "# 初始化文本编码函数\n",
    "character = '0123456789abcdefghijklmnopqrstuvwxyz'\n",
    "converter = CTCLabelConverter(character)\n",
    "\n",
    "# 推理过程\n",
    "preds_size = np.array([model_predict.shape[1]])\n",
    "preds_index = np.argmax(model_predict, 2)\n",
    "preds_index = np.reshape(preds_index, [-1])\n",
    "preds_str = converter.decode(preds_index, preds_size)\n",
    "print('Predict: ', preds_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 总结与扩展 ##"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以上就是cnnctc文本识别样例离线推理的运行结果了，可以看到最后的验证结果，成功识别了示例图片中‘PARKING’的字样。\n",
    "注意：\n",
    "1. 若出现推理失败的情况，请确保以root权限设置好环境变量（运行或参考文件夹内的env.sh文件）。\n",
    "2. 再次进行推理清清空所有缓存。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
